Skip to content

Preserve SpeechLM perception checkpoint dtype#15686

Merged
pzelasko merged 12 commits into
NVIDIA-NeMo:mainfrom
DongjiGao:speechlm-perception-checkpoint-dtype
May 13, 2026
Merged

Preserve SpeechLM perception checkpoint dtype#15686
pzelasko merged 12 commits into
NVIDIA-NeMo:mainfrom
DongjiGao:speechlm-perception-checkpoint-dtype

Conversation

@DongjiGao
Copy link
Copy Markdown
Contributor

@DongjiGao DongjiGao commented May 11, 2026

Summary

  • Stop forcing the SpeechLM vLLM perception module to FP32 when loading checkpoint weights.
  • Keep raw audio input/preprocessing in FP32, then cast processed features to the encoder checkpoint dtype before encoder execution.
  • Cast final audio embeddings to the initialized LLM dtype before inserting them into the language model stream.

Test plan

  • python3 -c "import ast; ast.parse(open('/home/dongjig/NeMo_merge/nemo/collections/speechlm2/vllm/salm/model.py').read()); ast.parse(open('/home/dongjig/NeMo_merge/nemo/collections/speechlm2/modules/perception.py').read())"
  • VoxPopuli full local vLLM plugin check on NemotronH SpeechLM: WER 9.07, RTFx 814.54 at /data/dongjig/results/quantization/speechlm_bf16_perception_checkpoint_dtype_voxpopuli_20260511_095950/result.json.
  • ASR leaderboard dtype comparison completed for LibriSpeech clean/other, TEDLIUM, SPGISpeech, VoxPopuli, GigaSpeech, and first 1024 Earnings22 samples under /data/dongjig/results/quantization/speechlm_leaderboard_perception_dtype_20260510_150218; no meaningful WER regression observed for BF16 perception.

Avoid forcing the SpeechLM audio perception module to FP32 during vLLM inference so BF16 checkpoints can run the encoder in their stored dtype while keeping raw audio preprocessing in FP32.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@DongjiGao DongjiGao requested a review from pzelasko May 11, 2026 17:22
DongjiGao added 2 commits May 11, 2026 10:34
Move the processed-feature dtype cast out of the shared perception module and into the SpeechLM vLLM model path so this fix remains scoped to plugin inference.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Keep the dtype conversion scoped to the SpeechLM vLLM plugin path and leave the shared perception module unchanged.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
@DongjiGao DongjiGao force-pushed the speechlm-perception-checkpoint-dtype branch from e369a20 to 36c35e2 Compare May 11, 2026 17:40
Copy link
Copy Markdown
Collaborator

@pzelasko pzelasko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great fix but I think there is too much defensive dtype casting, can we minimize to the absolutely necessary ones only?

Comment thread nemo/collections/speechlm2/vllm/salm/model.py Outdated
Comment thread nemo/collections/speechlm2/vllm/salm/model.py Outdated
Comment thread nemo/collections/speechlm2/vllm/salm/model.py Outdated
DongjiGao added 3 commits May 11, 2026 11:08
Call the audio preprocessor directly before casting features to the perception encoder dtype, keeping the dtype fix scoped to the plugin inference path.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Keep the perception module in the checkpoint dtype while loading the original tensors directly.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Preserve the existing BF16 LLM boundary cast and keep the PR focused on avoiding FP32 perception weights.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Comment thread nemo/collections/speechlm2/vllm/salm/model.py Outdated
DongjiGao added 3 commits May 11, 2026 12:07
Keep raw audio preprocessing in FP32 and run perception in BF16 for the vLLM plugin path without extra defensive dtype detection.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Perception outputs already follow the plugin perception dtype, so avoid an extra cast before returning audio embeddings.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Rely on AudioPerceptionModule to handle preprocessing and encoder handoff after the plugin sets the perception module dtype.

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
@pzelasko
Copy link
Copy Markdown
Collaborator

/ok to test 82fa4af

@pzelasko pzelasko enabled auto-merge (squash) May 11, 2026 20:32
@pzelasko
Copy link
Copy Markdown
Collaborator

/ok to test 8653bac

@DongjiGao
Copy link
Copy Markdown
Contributor Author

/ok to test 83974ea

@github-actions
Copy link
Copy Markdown
Contributor

[🤖]: Hi @DongjiGao 👋,

We wanted to let you know that a CICD pipeline for this PR just finished successfully.

So it might be time to merge this PR or get some approvals.

@pzelasko pzelasko merged commit 5a855e7 into NVIDIA-NeMo:main May 13, 2026
153 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants